Source code for hysop.backend.device.opencl.opencl_kernel_autotuner

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np

from hysop.tools.htypes import check_instance
from hysop.tools.units import bytes2str
from hysop.tools.misc import prod
from hysop.backend.device.opencl import cl, cl_api
from hysop.backend.device.kernel_autotuner import KernelAutotuner, KernelGenerationError
from hysop.backend.device.opencl.opencl_autotunable_kernel import (
    OpenClAutotunableKernel,
)
from hysop.backend.device.opencl.opencl_kernel_statistics import OpenClKernelStatistics


class OpenClKernelAutotuner(KernelAutotuner):
    def __init__(self, name, tunable_kernel, **kwds):
        super().__init__(name=name, tunable_kernel=tunable_kernel, **kwds)
        check_instance(tunable_kernel, OpenClAutotunableKernel)
        self.cl_env = tunable_kernel.cl_env
        self.typegen = tunable_kernel.typegen

    def autotuner_config_key(self):
        """Caching key for autotuner results."""
        # (a hedged usage sketch for this key is given after the class)
        return (
            self.typegen.__repr__(),
            self.cl_env.platform.name.strip(),
            self.cl_env.device.name.strip(),
            self.build_opts,
        )

    def _print_header(self, *args, **kwds):
        cl_env = self.cl_env
        verbose = super()._print_header(*args, **kwds)
        if verbose:
            print(f" *platform: {cl_env.platform.name.strip()}")
            print(f" *device: {cl_env.device.name.strip()}")

    def collect_kernel_infos(self, tkernel, extra_parameters, extra_kwds):
        """
        Collect kernel infos before computing workload and work group size.
        """
        kernel_name, kernel_src = tkernel.generate_kernel_src(
            global_work_size=None,
            local_work_size=None,
            extra_parameters=extra_parameters,
            extra_kwds=extra_kwds,
            tuning_mode=False,
            dry_run=True,
        )
        prg, kernel = self.build_from_source(
            kernel_name=kernel_name,
            kernel_src=kernel_src,
            build_options=[],
            force_verbose=False,
            force_debug=False,
        )
        check_instance(prg, cl.Program)
        check_instance(kernel, cl.Kernel)

        # (see the work-group info query sketch after this class)
        kwgi = cl.kernel_work_group_info
        max_kernel_wg_size = kernel.get_work_group_info(
            kwgi.WORK_GROUP_SIZE, self.cl_env.device
        )
        preferred_work_group_size_multiple = kernel.get_work_group_info(
            kwgi.PREFERRED_WORK_GROUP_SIZE_MULTIPLE, self.cl_env.device
        )
        return (max_kernel_wg_size, preferred_work_group_size_multiple)

    def check_kernel(self, tkernel, kernel, global_work_size, local_work_size):
        check_instance(kernel, cl.Kernel)
        cl_env = self.cl_env
        device = cl_env.device

        # (the local-memory check is illustrated by a sketch after this class)
        lmem = cl.characterize.usable_local_mem_size(device)
        kwgi = cl.kernel_work_group_info
        max_kernel_wg_size = kernel.get_work_group_info(kwgi.WORK_GROUP_SIZE, device)
        kernel_local_mem_size = kernel.get_work_group_info(kwgi.LOCAL_MEM_SIZE, device)

        wgs = prod(local_work_size)
        if wgs > max_kernel_wg_size:
            msg = "Work group size {} exceeds maximum kernel work group size {} for kernel {}."
            msg = msg.format(wgs, max_kernel_wg_size, kernel.function_name)
            raise RuntimeError(msg)
        if kernel_local_mem_size > lmem:
            msg = "Maximum usable device local memory size {} exceeded for kernel {} which "
            msg += "needs {}."
            msg = msg.format(
                bytes2str(lmem), kernel.function_name, bytes2str(kernel_local_mem_size)
            )
            raise RuntimeError(msg)

        cl_version = tuple(map(int, cl_env.cl_version))
        if cl_version >= (1, 2) and (device.type == cl.device_type.CUSTOM):
            max_kernel_global_size = kernel.get_work_group_info(
                kwgi.GLOBAL_WORK_SIZE, device
            )
            if np.any(np.greater(global_work_size, max_kernel_global_size)):
                msg = "Global size {} exceeded for kernel {} which allows only {} for "
                msg += "the custom device {}."
                msg = msg.format(
                    global_work_size,
                    kernel.function_name,
                    max_kernel_global_size,
                    device.name,
                )
                raise RuntimeError(msg)

        if cl.characterize.has_struct_arg_count_bug(device, ctx=cl_env.context):
            msg = "Device has struct argument counting bug."
            raise RuntimeError(msg)

    def check_kernel_args(self, kernel, args_list):
        for i, arg in enumerate(args_list):
            if isinstance(arg, cl.MemoryObject) and (arg.context != kernel.context):
                msg = "OpenCL kernel buffer argument {} context differs from the kernel "
                msg += "context for kernel {}."
                msg = msg.format(i, kernel.function_name)
                raise RuntimeError(msg)

    def build_from_source(
        self, kernel_name, kernel_src, build_options, force_verbose, force_debug
    ):
        """
        Build an OpenCL program from source and return the unique kernel it
        contains as a (program, kernel) pair.
        """
        prg = self.cl_env.build_raw_src(
            src=kernel_src,
            build_options=build_options,
            kernel_name=kernel_name,
            force_verbose=force_verbose,
            force_debug=force_debug,
        )
        kernels = prg.all_kernels()
        assert len(kernels) == 1
        kernel = kernels[0]
        return (prg, kernel)

    def bench_one_from_binary(
        self,
        kernel,
        target_nruns,
        global_work_size,
        local_work_size,
        old_stats,
        best_stats,
    ):
        """
        Bench one compiled kernel by executing it until target_nruns profiled
        runs have been accumulated, or until the candidate is pruned because
        it is slower than the current best. Return a (stats, pruned) pair
        where stats is an OpenClKernelStatistics instance.
        """
        # (the event-based timing is illustrated by a sketch after this class)
        cl_env = self.cl_env
        ctx, device = cl_env.context, cl_env.device
        profiling_enable = cl.command_queue_properties.PROFILING_ENABLE

        global_size = tuple(global_work_size)
        local_size = tuple(local_work_size)

        assert target_nruns >= 1
        if old_stats is None:
            stats = OpenClKernelStatistics()
        else:
            stats = old_stats

        pruned = False
        try:
            with cl.CommandQueue(ctx, device, properties=profiling_enable) as queue:
                assert queue.properties & profiling_enable
                while (stats.nruns < target_nruns) and (not pruned):
                    try:
                        evt = cl.enqueue_nd_range_kernel(
                            queue, kernel, global_size, local_size
                        )
                    except cl_api.RuntimeError:
                        raise KernelGenerationError()
                    evt.wait()
                    stats += OpenClKernelStatistics(events=[evt])
                    if best_stats is not None:
                        pruned = (
                            stats.mean
                            > self.autotuner_config.prune_threshold * best_stats.mean
                        )
        except Exception as e:
            print(e)
            print()
            msg = "\nFATAL ERROR: Failed to bench kernel global_work_size={}, local_work_size={}"
            msg = msg.format(global_work_size, local_work_size)
            # try to dump the offending kernel source before re-raising
            msg += "\nTrying to dump source kernel to '/tmp/hysop_kernel_dump.cl'...\n"
            print(msg)
            with open("/tmp/hysop_kernel_dump.cl", "w") as f:
                f.write(kernel.program.source)
            raise
        return (stats, pruned)
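

# ---------------------------------------------------------------------------
# Illustrative sketches (not part of the original module API). The helper
# functions below are hypothetical and only demonstrate, under stated
# assumptions, the pyopencl mechanisms the class above relies on.
# ---------------------------------------------------------------------------


# A minimal sketch of why autotuner_config_key() returns a tuple: tuples are
# hashable, so the key can directly index a dict of previously autotuned
# results. All key fields and cached values below are made up.
def _sketch_results_cache():
    cache = {}
    key = ("typegen-repr", "Some Platform", "Some Device", ("-cl-fast-relaxed-math",))
    cache[key] = {"local_work_size": (64, 1, 1), "mean_ns": 12345}
    return cache.get(key)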
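

# A hedged sketch, assuming a working OpenCL platform: query the same
# per-kernel limits that collect_kernel_infos() reads (WORK_GROUP_SIZE and
# PREFERRED_WORK_GROUP_SIZE_MULTIPLE) on a trivial hand-written kernel. The
# kernel source and function name are invented for illustration.
def _sketch_query_work_group_info():
    ctx = cl.create_some_context()
    device = ctx.devices[0]
    src = "__kernel void scale(__global float *x) { x[get_global_id(0)] *= 2.0f; }"
    kernel = cl.Program(ctx, src).build().scale
    kwgi = cl.kernel_work_group_info
    max_wg = kernel.get_work_group_info(kwgi.WORK_GROUP_SIZE, device)
    multiple = kernel.get_work_group_info(
        kwgi.PREFERRED_WORK_GROUP_SIZE_MULTIPLE, device
    )
    return (max_wg, multiple)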
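

# A hedged sketch of the local-memory check performed by check_kernel():
# compare the kernel's static local-memory footprint against the usable
# device local memory reported by pyopencl.characterize. The arguments are
# assumed to be a built cl.Kernel and its cl.Device.
def _sketch_check_local_mem(kernel, device):
    import pyopencl.characterize  # ensure cl.characterize is available

    kwgi = cl.kernel_work_group_info
    needed = kernel.get_work_group_info(kwgi.LOCAL_MEM_SIZE, device)
    usable = cl.characterize.usable_local_mem_size(device)
    if needed > usable:
        msg = "Kernel needs {} of local memory, only {} usable."
        raise RuntimeError(msg.format(bytes2str(needed), bytes2str(usable)))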
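

# A hedged sketch of the timing mechanism behind bench_one_from_binary():
# enqueue the kernel on a profiling-enabled queue and read the elapsed time
# from the returned event. The context, kernel and sizes are assumed to be
# preconfigured, with kernel arguments already set (e.g. via kernel.set_args).
def _sketch_profile_one_run(ctx, kernel, global_size, local_size):
    props = cl.command_queue_properties.PROFILING_ENABLE
    with cl.CommandQueue(ctx, properties=props) as queue:
        evt = cl.enqueue_nd_range_kernel(queue, kernel, global_size, local_size)
        evt.wait()
        # event profiling counters are expressed in nanoseconds
        return evt.profile.end - evt.profile.start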